本日將簡單介紹Pytorch-Lightning,而在包含今日的未來幾天內,會將先前構築的程式碼,分段整合成Pytorch-Lightning的格式。本日的部份是forward propogation的部份。
Pytorch Lightning是一個標榜同時可以簡化工程作業量,又同時具備高擴充性的Pytorch相容框架。
其與Pytorch的關係,有點類似TensorFlow與Keras。
(註:筆者本人也用過一陣子Keras跟TensorFlow 2,Keras的操作更加簡易,但是Flexibility就不太令人滿意,尤其是要客製一些框架或是訓練策略的時候,反倒是TensorFlow 2還順手一些)
主要的概念跟作法,可以直接參考下列這個來自Pytorch Lightning官方文件 LIGHTNING IN 15 MINUTES的簡介影片:
簡單看完影片以後,相信大概能有個概念。現在來舉一個最簡單的例子,讓我們先來回顧前幾日的train.py裡頭每個epoch的training跟validation是怎麼做的?
Training Phase:
inputs, labels = batch['img'].to(device), batch['labels'].float().to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_function(outputs, labels)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
Validation Phase:
for batch in pbar:
step += 1
val_images, val_labels = batch['img'].to(device), batch['labels'].to(device)
y_pred = torch.cat([y_pred, model(val_images)], dim=0)
y = torch.cat([y, val_labels], dim=0)
pbar.set_description('Validating ...')
y_prob = torch.nn.Sigmoid()(y_pred)
loss = loss_function(y_pred, y.float()).item()
有沒有發現驚人的重工之處?
基本上都是在forward propogation以後在計算loss,只是差在有沒有梯度下降的差異而已。而且其實重複的且類似但又不太一樣的程式碼,也是增加進行小修改時出錯的風險。
Pytorch Lightning 裡最核心的api就屬 Lightning Module,只要把模型整合成這個物件,基本上就可以開啟 Pytorch Lightning內的各種強大的支援。
根據文件中的內容,可以透過這個api把上一段落內的training跟validation大致整成下面的架構如下:
import pytorch_lightning as pl
class MultiLabelsModel(pl.LightningModule):
"""
Lightning Module of Multi-Labels Classification for ChestMNIST
"""
def __init__(self, CONFIG):
self.backbone = get_backbone(CONFIG)
... # 可以網羅各種的初始設定,通常我會把大部分的超參數
... # 還有一些實驗過程需要的額外物件放在這個地方
def forward(self, x):
y = self.backbone(x) # model inference 的主體,使用很自由
return y # 不論是要加層,增加input或output都可以簡單實現
def step(self, batch: Any):
inputs, labels = batch['img'].to(self.device), batch['labels'].to(self.device)
preds = self.forward(inputs)
loss = self.loss_function(preds, labels.float())
return inputs, preds, labels, loss
def training_step(self, batch: Any, batch_idx: int):
inputs, preds, labels, loss = self.step(batch)
return loss
def validation_step(self, batch: Any, batch_idx: int):
inputs, preds, labels, loss = self.step(batch)
return {
'preds' : outputs,
'labels' : labels
}
def validation_epoch_end(self, validation_step_outputs: List[Any]):
preds = torch.cat([output['preds'] for output in validation_step_outputs], dim=0)
labels = torch.cat([output['labels'] for output in validation_step_outputs], dim=0)
probs = torch.nn.Sigmoid()(preds)
# compute metrics and log
acc_score = torchmetrics.functional.accuracy(probs, labels, mdmc_average = 'global')
auc_score = monai.metrics.compute_roc_auc(probs, labels, average='macro')
...
透過呼叫共用的step
,就可以讓分別對應的training_step
與validation_step
都能實現與原先相同的forward propogation。而要蒐集整個validation結果,進而計算accuracy與auc的部份,則可以在validation_epoch_end
內,會自動將每個validation_step
的output作為input輸入,就可以計算整個驗證集的指標了。
如此切割各個功能後,除了可讀性上比較好一些,要debug也會比較容易一些,可說是好處多多。後續還有許多設計檔的瑣碎工作需要做,就讓我們挪到後續幾天再來慢慢完成!